{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# encoding: utf-8\n",
    "\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib\n",
    "# Render matplotlib figures inline in the Jupyter notebook\n",
    "%matplotlib inline\n",
    "# matplotlib.use('Agg')  # switch to this backend when running on a server\n",
    "import matplotlib.pyplot as plt\n",
    "import os"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "load_norm shape: (105192, 1)\n",
      "[[-0.85610657]\n",
      " [-0.99451862]\n",
      " [-1.14643429]\n",
      " [-1.2432102 ]\n",
      " [-1.32423189]\n",
      " [-1.3366102 ]\n",
      " [-1.23308249]\n",
      " [-1.17906803]\n",
      " [-1.08454272]\n",
      " [-1.03165356]]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYcAAAD8CAYAAACcjGjIAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAAIABJREFUeJztnXd4VFX6x78voYaaQOglIJGulAgooiJVcMW6oi5iWVlddC27+oO14Mqq6K6r6+qqrKBYEVERBSkiRVFKQEqoCRAgQCBAgFBC2vn9MXfCzfS595xbZt7P8+TJzJlzzzk3d3Lfe95KQggwDMMwjJ4qdi+AYRiGcR4sHBiGYRg/WDgwDMMwfrBwYBiGYfxg4cAwDMP4wcKBYRiG8YOFA8MwDOMHCweGYRjGDxYODMMwjB9V7V6AURo1aiRSU1PtXgbDMIyrWLt27REhREq4fq4VDqmpqcjIyLB7GQzDMK6CiPZE0o/VSgzDMIwfLBwYhmEYP1g4MAzDMH6wcGAYhmH8YOHAMAzD+MHCgWEYhvGDhQPDMAzjBwsHhmEYF7Ax9zjeXJJt2XyuDYJjGIaJF4pKynDdGysAAOMGtLdkTt45MAzDOJzdR05bPicLB4ZhGIczaspKy+dk4cAwDONwTpwtsXxOFg4MwzCMHywcGIZhGD9YODAMwzB+sHBgGIZh/GDhwDBMzJM6fi4em7ne7mW4ChYODMPEBV+u22/3ElwFCweGYWKaTbkn7F6CK2HhwDCM5czdeBDDXltuyVzPzMm0ZJ5Yg4UDwzCWM+6TddiWV4jteYV2L8V1HD11zpJ5wgoHIppGRIeJKFPXdgsRbSaiciJK9+k/gYiyiWg7EQ3VtQ/T2rKJaLyuvS0RrdLaPyOi6rJOjmEYZ3OutEzZ2KVl5fj7t1vw697jyuawg++3HrJknkh2Du8DGObTlgngRgCV9oVE1BnAKABdtGP+S0QJRJQA4E0A1wDoDOA2rS8AvATgVSFEewAFAO41dioMw7iN/QVnlY392vdZePen3crGt4uTZ0stmSescBBCLAdwzKdtqxBie4DuIwHMEEKcE0LsBpANoLf2ky2E2CWEKAYwA8BIIiIAVwOYpR0/HcD1hs+GYRhX8cDH63DsdLGSsTfkxtaOwYuAsGQe2TaHFgD26d7nam3B2hsCOC6EKPVpZxgmTigsUpNU7sesI0rGtZtTRQ7ZOTgJIhpLRBlElJGfn2/3chiGYVBSVo7U8XOROn4uCotKsO/YGaXzvf6DNdXgZAuH/QBa6d631NqCtR8F0ICIqvq0B0QIMUUIkS6ESE9JSZG6cIZhYh/vTVwm+kI83Z5diP4vL5E6vl3IFg5zAIwiohpE1BZAGoDVANYASNM8k6rDY7SeI4QQAJYAuFk7fgyAryWviWEYhomSSFxZPwXwC4AORJRLRPcS0Q1ElAvgUgBziWgBAAghNgOYCWALgPkAxgkhyjSbwoMAFgDYCmCm1hcA/g/AY0SUDY8NYqrcU2QYxg3c/d5qjJm22u5lMBpVw3UQQtwW5KOvgvR/HsDzAdrnAZgXoH0XPN5MDMPEITMz9uHxoR2xZLt1dsSycoGEKmTZfG7EVQZphmFijzeX7IRHwyyPj1buCfn56WI5Hj+HThZhzvoDUsZyGiwcGIaxnaW6XcM3G8zfbJ+aHTqf0s/Zctxcb//fSryxxBrvIath4cAwjO0UlZxPo/HQp7+irFxtoNdRSYF3eSeKpIzjRFg4MAxjO0UKcywFolxAuipLz6lz1gSqqYSFA8MwtvPoZxsqvVdtKn56diaen7vV9DiniwMLtZfnbzM9tpfDhfbsTlg4MAwTl0xboS4p3/p98vI6vbciR9pY0cDCgWEYS9l7NHx6iR+2HVa+DpVmjXKJKitZxvNoYeHAMIyljHkvfKCbPiWFG9ly4KS0sTbYVOaUhQPDuBS3Gj3PRBBjoFLlYwWK
na0sgYUDw7iQn7OPoOvEBfhJQVrq7XmF6PLMfPyY5Yk92Hv0DE5KTKtdcCb8WAdj2EXULbBwYBgXsmq3p/7WmpxjYXpGz58/X4/TxWUYPXU1/rM4C1f8YwkuenahtPGLS8uljcWog4UDwzBBeWXRDruXwNgECweGcSEqVdoKY8Ncg7eAz9fr9+Pn7COYsXpv1GNsdHmZ0rBZWRmGcR5HT50DAGTska9W2izR08ateNNiPDxjfUXbqN6toxrjujdWIGfyCKnrshLeOTCMC/l4ledJdkX2UZtXwtiBytQfXlg4MAzjSopKypQl6HO6aq24TL1Rn9VKDMO4jvJygY5PzwcA21Q3g/61zJZ5rYJ3DgzDuA7VWVxFGJO/EALZh0+FHWd+5kFZS7IcFg4Mw1RQYoG6wg3IUivd/9E6OQPZAAsHhmEqSHvyO7uXYJpzEnYVo6etkrAS8yzcnGfb3CwcGIZxHa8vDl6ac9baXNPj7zt21vQYMth6sNC2uVk4MAzjOt5etrPi9eYDlbOWqi4xCljnzSRjF2QUFg5MzCOEwIrsI5b4hjPWM+L1nyq9P3DcnUn73lySjR2HKu8Ulm7Pt2k1LByYOOCbjQdxx7urKgLHVNNr0iI8NnN9+I5MSMoN7gD0uwq3UFJWjn8s2I4b3lxRqX3LwcDR6lY857BwYGKe/QUe/fG+gvAVyIxyprgUqePn4oGP1uLo6WJ8uW6/knke/GSdkkysVvHnmRvCd9LoPHF+3O32ihyUsZaFAxM/KLzPPPjJrwCA7zLVeZecKS7FtxsP4pa3f1E2h2q+WBe5sbiopByrdztHEBacLq54LfOr9M6ynY70EmPhwMQ8RGrH355X6Pqax5FQKLHgT6QEMi6rrqkcrMJej0mLlMz34nfbKl5HulNS/Z0GWDgwjGmGvrbc7iVYQrdnF6KoxD7vGS/PzNksbaxAQX/lEdygTxWpKdFaLoAPf8lRMna0hBUORDSNiA4TUaauLZmIFhFRlvY7SWsnInqdiLKJaCMR9dQdM0brn0VEY3TtvYhok3bM60RWyEQmHnG79nprEOOklZyzWice4G4g0w6xM98/BUaVCG5BR0+fi3iOvUejs3U9/bU84WeGSHYO7wMY5tM2HsBiIUQagMXaewC4BkCa9jMWwFuAR5gAmAigD4DeACZ6BYrW5z7dcb5zMYwpDh63J6Bp3ia5eXUe0uwaqvg0goI2oe6b2/JOInX8XL+4A9kc1en+zVJa5i9oqkh+PB3ymjsT9IUVDkKI5QB8rUIjAUzXXk8HcL2u/QPhYSWABkTUDMBQAIuEEMeEEAUAFgEYpn1WTwixUngeBz7QjcUwUpj+yx4A1ucN+uPHcvPq5J1U678/4ctNYfscDrGGhZsPAQDmSzTKU4Ctw/Ez8mwfZ4r91WSB5jRDUUng793hwiI8K1FFJhujNocmQgjvY1EegCba6xYA9un65WptodpzA7QzjHTeW5Fj+ZxFJWUBVRduZfXuAruXIJV73l/j12aVYvuprzLx/s851kxmANMGae2J3xJ1LhGNJaIMIsrIz7cvcpBhIqXj0/Mx8JVlOFOsxoBpNX/9KvzuQiaqb9TBPJPCUSrBdWzhlkMhP7c7xsOocDikqYSg/fb68e0H0ErXr6XWFqq9ZYD2gAghpggh0oUQ6SkpKQaXzjDWU2zSkOsk4WLlTihS2XD8jDw7RCi253nSWwx5Vb2H2q/7jiufIxRGhcMcAF6PozEAvta136l5LfUFcEJTPy0AMISIkjRD9BAAC7TPThJRX81L6U7dWAwTM5jVY090iAcLAAx8JXIDa96JIhw5Fblnjy+ROi/e8a68FNtzNhwI+tlJC2M93vwheOZZK4jElfVTAL8A6EBEuUR0L4DJAAYTURaAQdp7AJgHYBeAbAD/A/BHABBCHAMwCcAa7ec5rQ1an3e1Y3YCcF6oIMOYxaR65PMQaahLIzS0L9+Rj7PFZSgsKrFMZdH3xcVI//v3yufZ
fMDj5hut26gv5eUCT8zaKGNJprGiTnQowtaQFkLcFuSjgQH6CgDjgowzDcC0AO0ZALqGWwfDuBnZ7pF62j/5Xdg6ylmHCnHntNW48sIULNuRj78MuRAPXp2mblGSiNbm8OWv5mo5lIURmlaaAQ6esDe7LEdIM4wFyHS/NMJJLaJ32Q6PI8fcTWorjBWXlvtlR3WS3SQY4W7+VhqJI6lRrRIWDoyt5J0osrQU4to99iRy22NS3RGO1PFzkXPktNI5omHqT7sxWZczCADOBfH3D0W0G66zAeIWouG7TLmBi16stFXIgoUDYys3/ncFxn641rBLYbTc9JZ7M5qGY30U3i1nFT7F/+eHbCzc4i/wjbilFkSx4zp0sgjvLN8V/SQ6PglT88PovuG5b7YYPNI+WDgwtnJA06t2nbgAmfvVpl2IdSJJGOcl5+gZpXUhft3rL6iMeGzd90FGxH37vLA46vGjxahWSUZda6th4cA4hmv/81P4TkxQPluzL3wnHS/5qH1UUlJWDnLB3SbcvX/OBjVFnJyICy4XYwdLth3G0u3naxQMfXW5o/PARMPKXUctn3NbnvqMqquiLIyTsce6VBgPfrJOcsYieVTKFxVGOuy2yK7jBOM9CweXUV4uLPni3P3+Gtz13vm8M9sPFTo6D0w0jJqyEj9lqS0Y48vf5261dD5fZKeh2BBl9O6CzaFTRYRCX2RoUZiUE0bYcei8V1A41ZzspHzBcEJ1VBYOLmPS3C3o/MwCnLbIgBurHFKc4TTWWbU7+t2X0VIt+kppkWSOjZbfTT0fXR3unmxVUj5vUJ+dsHBwGd7MoiwczPGT4lKTsY6RSmhG76uVq8/Z+0htlXC4+73V1kwUAhYOkrHKnzlYjni3k2dRVOhXv8aPYREIr6bYYsGTqtEba4muIM+RU2oT7IULcisuLUcvRbWk9Zw2Ga8hAxYOEvl6/X5c9OxC5ZWwACDTgjnsoO+Lct0R7U57HCklZeUok5AG2ijnSqO7GW0xULLU6KX4JkQiPNmEW+KanAKpleiMYsXXmoWDRLypCbYeLFQ+V26B2ohbL5tyzQmhzzP2SS+XGQ3zQqSJ2OWgIjxpT36HMdPsUyVEG1n8/dbD4Tu5EJc8S6BmtQTlc7BwcCkvzNuGA4pqI6/SuXrOWLPX1NP347M2Si+XGQ3HQhSC13upOAE77SAPfWq+PnW4mhWnHeCeaSeFRSWY/N02lJSVW5oyxigsHFzMvdPPR4/+mJUfcermcKzcdd5f/uNVe/Hhyj1SxrWDQDWC3caeo+Z968Pp+2WoSnYcCr1jznKYMLaaVxbuwNvLduKrdfsx9sO1di8nLCwcXMzWgydx4ZPf4cesfIyeuhqvL86SMq6vUd1qFcKCzXnoOWlR1HrwQJgpNGMVep362eIyfPBLTqXdWjRGWKsM+oGwypNHJSq1Sl6vKxklRq2AhYMCrDSCFpeVV6T2zZGU+dPXMLp8x/l63TLjAzbmBg6kmvTtFhw7XYzDJ9Xe2J1yM9OrdF6avw3PfL25Un3hgiie6rcctM9RIVyAmBv0+Vb87zrlexcOFg4SWa8lG5v6025L5/UajSv7gxsnVHoJmamnb5uyUtpYwVD9v14iuVpXgVYLWW8gjuYUrIrgBYDr3qicCyvcTa+03Lnu10e1HWa1BHW3RDcIRz0sHCSyT/Mg2pan3ltJj7fg+0JJqQVCrX+txHw8Vvhyh/p//MOHa00HE37wixp7zCuLtoc18AbEwqfSjVF6sunTsTiN3AKPc8cFKbWVz+WSjQMLh1hgg0l302h4ab76TJ7ef1QZhHtaW7c3vLA7caYkaHrrfcfkuhQf01RI+46dxYw1oWsLBMTGp1OV6hInuR0bRdgc3R0tLBzijKKSMrzxQ5Y0zyanE02Ng2Dc/f5q3PL2LwEN5LKTEf6oSwj4w7boHQG+tDHyu7hU3c3v6leWYXWUWWejYeSbK/D+it3IOaIufsj7VWSbQxyiD/P3sin3
BFLHz43oCTUc90twf/vPD1n458Id+Cwjutz/dtD/5SXK54hER5+53xMNbFbO/JiVH76TjqXbPf1PnI08JUveieh3XYbUVwFQne9L9e7h2W+2YLXCAkjer4+VdiEzsHBQhNeDyFsT4QcJ7qDzJQTOeP3+zdbadQvhvE/0NSuCjqH9W5t94vs8w1g1sBfmRZ7ue01O9A8hd78vJzK7iuK7yYqd1tfhkIk4Lx1cQVW7FxCrrNtTgHV7CxynZXTKU4udKTX0RGLf8O4Ii0vLUaOq8bQFRr8LxyQEqIWqv7AiW85N9+BxtTEWVuZYUkGxpsqV5VWoGt45KOKJLzbiiVkbsV3z/HGKntG7Drvd6r5XULQlENPDeBNFc13mZ+Zhugkbg503tw8VeVXp2WVRlTS34r3+by3dafNKIiPuhIMQwtIgNaufEr4L80TuTXFwKoh+ONzxsshVlBdKz9YIModGIxwen7URE11QKjXQd45v3M5Blo1HNXEnHC55fjHaTphn2XxWqxkfCJPkzusN80uQQLdwx0dDKLuGSs8TL5HUbT4UIgo7c/8JfPhLjrwFWcSufOsEwbHTxej8zHxTOyrGmcSdzcGba+dMcSkSq6s//Qp3RKfolSxk3Cf2ZWMFgEc/2xC2T6igvmv/81PQzxgPq3cfw5niMlfsqJyCW24Fcbdz8GLl0xXgPAeFSFQuRpm1NhczVu815KdvlH3HzuCp2ZtsLZjjFLIOV45wDxWoF0y9GCmFFlU+ZKzHlHAgooeJKJOINhPRI1pbMhEtIqIs7XeS1k5E9DoRZRPRRiLqqRtnjNY/i4jGmDulyLD6qdBpTwuFRaVY5GMUllUf4i+fb8B4BYXgA7F0+2GUlws89Omv+GjlXmzQkvnJznkUiDkO9Z7xzaJbHOJvcfNbP5uaq8gl+nNn4bCbQRAMCwci6grgPgC9AVwM4Foiag9gPIDFQog0AIu19wBwDYA07WcsgLe0cZIBTATQRxtrolegxCuFRSV4a+lOlEt+Cv7Fx0/8vg8yKr0fPXVVROOYNejLMnpP+HIj7npvDR6a8SvWa66a3qXJiIwOxxdrjcUt6FHhwfTNhgPYrTNAh/pTmM0D9vTsTFPHxyfu2N2a2Tl0ArBKCHFGCFEKYBmAGwGMBDBd6zMdwPXa65EAPhAeVgJoQETNAAwFsEgIcUwIUQBgEYBhJtblSCKNL8g5chrdnl2Il+Zvq5QKQZ822yg/ZYceo+BMZCqCj0wW/5mxRk50tnecuRv1wsbzj1dUov6J9ngUkcvBCFeB7dZ3fjE07oB/LsWWAyctKyfLxB5mhEMmgP5E1JCIEgEMB9AKQBMhhPe/NQ9AE+11CwD6u0Ku1hasPaaIVK2kf5J8b8X51N/BvIvs4OmvN5va1SyTIOiC4X1K/ikr8pKbRtM+hAosk8UqE15dw1//EZe/pD4FCRObGBYOQoitAF4CsBDAfADrAZT59BGQuIciorFElEFEGfn56m4wKohENhwuLML3QYy4Rx1W0Wy/BXEK4QikLhEVvyP/2mUddn/Gz9C4Q43BOAtTBmkhxFQhRC8hxBUACgDsAHBIUxdB++292+2HZ2fhpaXWFqw90HxThBDpQoj0lJQUM0uXzs9hisO/smhH2Kft26asDPo0OtNgXp5oxrDKTFZF4USlZQLZhwvxWRSqq80H7KuexjBOxay3UmPtd2t47A2fAJgDwOtxNAbA19rrOQDu1LyW+gI4oamfFgAYQkRJmiF6iNbmKm5/N7wxd9a64DfneZsOYqck99oftgVOTZFf6IzdRxWFrlsLNudh8KvLK6W+DseTX2VWGLS9uKH2dKRYXQOcCU00NcHtxGwU2BdE1BBACYBxQojjRDQZwEwiuhfAHgC/1frOg8cukQ3gDIC7AUAIcYyIJgHwlol6TgihPnzWBnKCpDDYfeQ0/igxMvme9zOQM3mEtPFko9Ktt7is3FDeqDW7j6F7qwYV71XFwQghQBb7NVuRV4mx
hk3PDkHdmtUsmcuUcBBC9A/QdhTAwADtAsC4IONMAzDNzFrcwI5DgXXbwTKU2p0cTyaHC4vQuG5NANBujpGdXKCb6Ucr96BdkHKORv9meSfVZhT1cupcqWX/3F5kZHVl4o+4jZAGrK9pECoYKRBbFEYxmyXahIILMs/XoojG5vDxKv9SmU/NzsTt/4ssJiNSoo2LOOmiyOCzLkkRzTiLuBYOnZ6Z7xclbIRIdfnBgsesTuWh57lvthhKOTHi9egizAWAS19cjAlfbooqBkFGBb1IiNY1d8i/lhuaJ7/wnCXR2wxjlrgWDkD0pRsDEWnmzoycgoAlH0NVI3v3x10GVxUZ01bsxoownlaB8Oj2/W+owVJwCAEcPFGET1f77wRCEX1+JmN6pWjlo1E11NWvLMP/zdoIQG1+KyY2sVLTHPfCwUq9/tmSMlz8t4V+7aHsk3+fu1V5Mjnv6Eej1E0fDxBR/f3WwDsxoyk3As0RilApuJ2CN/L97WXuKPrCxCcsHGTIYoPeJzvzT+FEBDe/l+dvi3rsaT/trvT+TLH84u+BXFKDZfm0SgYbzQTrtTmcKy1DbsEZ/BRiN/XfpdmG5mAYNxF39Rx8kbFzWGmw8PnAV5ahTcPEsP3eWR69aum5b7fgnsvbVrz/eGVwdc5Bg9HOFODRItjf0y2eV499tgFzwyQGfHn+dilzuSM3J+MkqidY9zwf9zsHGazOMR6WsefoGdh9mzCaXjvQzuHbjdaUGZXNtrxC5Beew+IgAYQMYzdPjeiEmtUSLJsv7oXDzAw5GULjkU25/mknghlZ//NDlurlmGLtngIM+OdS1+xwmOh4eGCa3UswTb1a1sbHxL1wKCmL/bvBlgMn8fy8rdLHve1/KyPuG2k6cDs5da4U5ywsXnPYIelM4oGHrm6vfI51Tw/GxmeHqJvA4ltV3Nsc4oG/fB6+lnKw1B5OI+uQueI0TkEIgZ8N2qqY6Klqga4+uXZ15XNYSdzvHCJhe14hXpy31XQFND36FNy1qtt/GcIZYYMhhJD6dwnH4FeNBZ85jSe0WAeGiZS0JnUsnc/+u5ILuOPdlXhn+S6pOWoKi867fKY1rittXD0F2nojuXX/Y4ExD5y2E+bh1e+dbU9wIp9LKDHKOIdnru1c8Xpw5yYhehqnR2trqyezcIgAbxCazGyaViTm7DFpkfpJAHxssmxoOPYcdYfKi4lf9G7jU0b3snEl8mDhEAGylSZW59ZRrfY5eroYqePn4rXvdygZ/8p/LFUyLhP7jOzeHLemtwrfUSJEhKznr7F0ThWwcADw7cYD4TtBXjRCYVEpSDfaxlz1tYit4DVWLzEmmfuny6WO9+9RPfDSzRdJHTMSqlkYrKYK95+BBL5aF7AqaQXe/D5/+2Yz/i3hBrhoS14ltZJbKkMxjGyeHN6p4vWNPVqgS/P6Nq6G0cPCAZ7ArZve+hnHz4S+Sc9efwCvSlCdFJeWR10PwQzb8mLD/ZOJLe7o0xr3XdGu4n3N6tZF/zLhYeEA4MCJIqzdU4BvNkSmXjLLwRNFWLU7JiuhKiVW1G+MveRMHoE+bZPtXobjYeFgA/sKjCW6i2e+23QQ172xwu5lMApJtDBvkBVsUhktbQEsHHREWzsAMFaf95sNB/DU7Myoj4tnHvh4nd1LYBTTsVk9y+bqn9ZI+RxW1wqXDQsHHa8sisye8M6ynZifeRAni0rQ06JYAqMcNlixjIlvXrqpm+VzprfxBHk1rVdTyfhPDOuAJvVqAAAubKIm8NSXRnVqWDKPClg4GODF77bh/o/WVYpydiq9X1hs9xIYFzLiouaWz5naqDYAoHpVObcl393BH69qj1V/HQTAuhx2LZJqWTSTfFg4mOCsgupqDOME6tRQm5OzRtUqeGTQhQE/e2JYBylzvHBD8N3PBSnW5ikyiyyBGQ0sHAJwprgUq3Z5MmaeKw3ucnrstPPTUDOME/nmocuRUjewyuXai5pj94vDTc9RPzG4zr99
4zqWqHxkqed+Hn+1lHGigYVDAB6ftRG3TlmJA8fP4tVFwYPeVNRlZph4IJzOX0Yes3phDML1a6mvWNCxqXkj++a/DbXFdsHCIQBbD3iqmX24cg/eXrYzaL+73ltj1ZIYhpGMzESaqvj2octRW7GKLxgsHHwY/K9l2KUVvnlraXDBwDDxgNe7JxapW9P8TXfUJa3Qsak6z6euLexLJ8KV4HzIOnzK7iUwjCNY/Ocr0bB2dXR/zuOunTN5BMrLBdr9dZ7NK5ODDKP75JusT+pnFaZ2DkT0KBFtJqJMIvqUiGoSUVsiWkVE2UT0GRFV1/rW0N5na5+n6saZoLVvJ6Kh5k6JYRgZXJBSBw0Sq2PmHy7FI4PSAABVqjhfFRMp9/VvF75THGNYOBBRCwB/ApAuhOgKIAHAKAAvAXhVCNEeQAGAe7VD7gVQoLW/qvUDEXXWjusCYBiA/xJRbMXRM4yL6OajyujdNjmo26mbaVpfTbBdrGDW5lAVQC0iqgogEcBBAFcDmKV9Ph3A9drrkdp7aJ8PJI9FaCSAGUKIc0KI3QCyAfQ2uS6GYQwy5rJUu5cAALjuYrWBeFZFSbsVw8JBCLEfwD8B7IVHKJwAsBbAcSGE18czF0AL7XULAPu0Y0u1/g317QGOYRjGYoZ2UVMDOVpeuNH6FB7MecyolZLgeepvC6A5gNrwqIWUQURjiSiDiDLy8/NVTsUwcYtTEsaZMRj/aWBaRP0GdEgxPEesY0atNAjAbiFEvhCiBMCXAPoBaKCpmQCgJQBvmbX9AFoBgPZ5fQBH9e0BjqmEEGKKECJdCJGeksIXlWHcSC0LUnM/Oigy4dC8gXtzH6nGjHDYC6AvESVqtoOBALYAWALgZq3PGABfa6/naO+hff6DEEJo7aM0b6a2ANIArDaxLoZhHIywIO1dpAFuqu0absbwvk0IsYqIZgFYB6AUwK8ApgCYC2AGEf1da5uqHTIVwIdElA3gGDweShBCbCaimfAIllIA44QQ1tXQZBjGUoRVKVEjoE+7hnYvwbGYigIRQkwEMNGneRcCeBsJIYoA3BJknOcBPG9mLQwTb9zUsyW+WJdr9zIcx6hLWoXvFCF92iYHLelbO8ZrXnP6DIZxKfVqVcXT13bGJ/e8ca0FAAAXVElEQVT1sXspUeGr8fnDFe3QTGLMQavkRGljhTJsN4txewWnz2AYl/LE0I6o5cKnV1+10oThnTBheCdp4/fSKsox5uCdA8NIok/bZEvn0wsGKzyAvKycMNCyuYzQ1yI7QuwkEgkMCweGkcQTwzraNvek67taNlfT+jUxaWQXw8df3bGxxNWoo2PTuujeqoFl87VwmJqKhQPDSKJagn3PklaXkbytd2vDx47s7o4ECPMfuQK1a1TFX4d3xBcPXOb3+YWSU3UnOkxFyDYHhpEEBVE0XNSyPjbmnlA6d6M61ZWO74uZQjnDujaVuBL1jL3iAr+2T+/ri4tbya210D8tBQM6NsaU5bukjmsU3jkwjGS6NK9cGrJf+0bK5wwmmNTNZz85k0fYNvelFzREYnV5z9bN69fEX4d3rEgG+NfhHbH1OaXZiMLCwoFhJOF9mBYC+O7h/hXtQ7vIfVKePa4fZo/rV6ktPTVJaUUyt/CHK6Ov0aA6SrpNw/CutS2SaqFqQhXc1LMF3v5dL/z+8na2e6KxcGAYyQgAnZrVw5+ubo+be7WUZtRMb5OEtU8NQvdWDfzGrJZQBVPvusT0HJ/ff2lE/YxqlRrXlVN29KaeLQH4B7z9Nj36ALhXfnuxlDUFIxpPMiLCsK5NHVFUiW0ODCOJmtU8z1pJiZ6spo8N6SB1/D8NTEPDOsFvrk0k3HgvSY3MHdeozWHJX64ydJwv3nvn0C5NMWPNPrx39yUY0MGYF1S1BLXPyP+7Mx39X16idA4VsHBgGEm0b1wXk67vimtsMrg6KGVRUGpLqNsMAGlN6gAABnRsjDVPDkKKpB2Jl6lj0v3a6tSoilPnSqNO
JR4uYntgx8b4v2vsc4MOBquVmLhk8Z+vVDLu6L5t0CjE070ZwgV3mU1oJzMnkZcP7zVX1PG9IKqy319+3rYgWzAAQHmAv+Xscf7urDKYetcljqxKx8KBiUsuSKlj9xKiRnUsgwqhZjZn0oCOjfHJ7/1zR6nWyYsAkrZtozpo0zARk2+KvkKdU6rrRQOrlRhGAquftD+lhInQA0fTvrH1gjzQziGhCmHZ4wMMjdeknrzEglbBOweGkUDD2qGfui+3INahWkIV/MuE580lknNDpTWug6pVzN9ikmurD/C7sEllARRo52AGO1OrGIWFA8NYgFV5c27UXDyN0O8CuQnrZv7hUjRINF+PuqpibyIA+Hrc5dBrqgLtHMxgph62XbBwYBgLsDugKRyXpCYpuQk3SLQ2rYdRalVPQLeW52NHyhWUq7MyiZ8MWDgwEWFl1k8V2B1TdE+/tpbN9XCIAjXBqFvT/BO+nseHdkCSBeogmfRopVY4XNRSbi4m1bBwYCJidN82di/BFDKrgxmhdQQpFGRxXffo00GcPFsidQ3jBrSXOl4s4DZ/ARYOTMRsmzQMw7u5K6Oml3FXnb9ZtUr26P9bJsmzAyTYvTXRUcWA21LGngJp8ydJsDPYwYNXn/+OqNg5mMlkawcsHJiIqVktAZe3T1E+T7cW8rff13VvXmEUHNzJI+Dev9tcgJZTCRUU1q5R7YDtMvXhCx8NHGAos060CvRxHgpkg+tg4cDYTsemdSslJ7upZwssN+hPHoya1RIw58F+qFUtAXf3SwVgj/+8UTKeGhRx31CeMcHcQmXuooIJpw/v9Q9mcyrpbeSXfHXZxoGFAxMdqQp05zWqJeCPV50vqNIqOVGJjr5dSh1snTRMuv0h0htr5t+GGp5DlivkR7/vg0kju+CWXpVdXq1wtZUhjGtpQl41Kr5/gdR93gcVJ8LCgQmKN+Rfb2dQYdjt2zYZXTVPjisuTMHATu5KNfD9Y5HlaTJzg68ZRdrncOOMvjTVL2J3gMK6zpHUM4iUZg1q4qKW7nIJ9RJo4zDxN8ZrcauGhUMM8OTwTkrGbZXk+adW7Z/9+NAOGNChMabdlY5pAbJhqkRGsXtZN24r8X2INbJziPR7Mev+y/BxgPxIRpBdOCkQqmqBs1rJ4Qyz4Mvl5TeKK0x5ue+Kdtg2SW5Jwbo1qgb8MqtI/uYNvrq6YxNLomH1PH+Ds+I3Hh/aAf+9o6fyeVr77ACN7AjX7zseUb+UujWklUr9i+QaGb7kTB6BrOeHKxmbvZUczqWSUwSEQq9HV03NaglSdxDlQuD2Pm2QlFgNIy46L+Sa1KuJ/9zWQ9o8dmN17eVwjBvQHsO7NavU1lBBMNnNOpvDoE7Gdk9Wfr8XPnoFXr31Yke5DEeL21Yed8LhzkutC+ZKlJwyoU+AxGh6n/J7L5cXhdu3XUO0bVQbvz4zxE/l8JuLm+OxwRdKmysYgyywPTSpp6b2ghHWPOnvkfTm7T0r1aOWBRFVCIVbL2ltaAzVKcT1XNikLm7oYTxvlCNwmXQwfHWJqAMRrdf9nCSiR4gomYgWEVGW9jtJ609E9DoRZRPRRiLqqRtrjNY/i4jGyDixEOtWOXwl2jQ871M+wudp0Agtk85v/Z8a4dkl6Ctrycxx37BO6KfVSMtJmsFbdtMswXz7Ac/3QcWTuRH0LqDZz1+DbZOGYcRFzdBYUrrnRmGuabToU24MVGjQjhVqVHWXbcrwf58QYrsQorsQojuAXgDOAPgKwHgAi4UQaQAWa+8B4BoAadrPWABvAQARJQOYCKAPgN4AJnoFSixwsWa0e12CKua5kR7Phhdu6KbcMBcuCKhtiBuuk7ixRwu8GUaHf98V7UJ+HorrFNmVqiZUkW7oDpY+22h6av0u3EjKjnjjepf9jWTlkR0IYKcQYg8RjQRwldY+HcBSAP8HYCSAD4Tnm7iSiBoQUTOt7yIhxDEAIKJFAIYB+FTS2mzlg7t7Izv/lBRdae0a
VZEzeQQA4NDJIgBA8/pq/NP/FCZ5myx1zDPXdpYyTjD+dWt3peO7SQVev5ZvWgtzi6+WUAXz/tQf905fg/5p6iPn3U41i50tzCJrtaNw/mbeRAhxUHudB8CrOG4BYJ/umFytLVh7TFA/sRp6tfFshOpKzOnepF5NvHF7D7w9upe0Mb08f0PXsN4rstRz90i0k/hSt2Zkf28zZzKyu5yv6uQboy89GYpApTV7tpHvkty5eT38MmGgJQV5GGsxLRyIqDqA6wB87vuZtkuQlqWEiMYSUQYRZeTn5xsex66nvYnXyQ14ufai5n7/lO+M7mW64PrAjtYEoYXT9ZsVQI8MUm80l6X6GdXbmFE4GJe1b1ShhvTy5IjKu7Tf9/cI5p5t5GhxL0lNwuND1bqaMtYhY+dwDYB1QohD2vtDmroI2u/DWvt+AK10x7XU2oK1+yGEmCKESBdCpKekGN/G2uVv3LFpXeVzDO3SNKDXS6SM7tsGTS1KkFYjjLdLhybG0y3UrVEV90SYmsCMys+qCm9G6Nuustu2b4R233YNkTN5RKWEc2b4/P7LOFV3CHxvO07fbckQDrehsn1gDgCvx9EYAF/r2u/UvJb6AjihqZ8WABhCREmaIXqI1qaMMtk1AH0Y2qUJ/iZ5lxCLfBQmavaBq4zfaN76Xa+IHwLu6BPevXn2uH74e4CCR9Hm4Akkh7yeZ7K5sIn6hxEmcny/j07fZZlSghNRbQCDAfxB1zwZwEwiuhfAHgC/1drnARgOIBsez6a7AUAIcYyIJgFYo/V7zmucVkXdGlVReK5UydjtGtXGO6MDp4BwQ4DkaAvjQNqlhN4ZmHmirxOhvQGIrIRn91YN0L1VAzw1O9PwmgBPZHLO0TOV2qyMF2CcwT9vubhSIKITMfWtFEKcFkI0FEKc0LUdFUIMFEKkCSEGeW/0wsM4IcQFQohuQogM3THThBDttZ/3zKwpEh4aqG7r+6oi75hZ91+qZFxfrHraNGsXCYdMORwsPsBIzqlAgX1NJcUxMM5G/50c6QK31rh8ZFHpUhaqClfjupHfBHxdPNOjDDq74kL1roULH73C8LGPDIq+zrET8KYsT65dHe8Y8BQLpNBUWVfCu/lSlUyOMYYb3Fqdv0IFqDQ5iBDOWdE8Lfdum4xnf2M8BuB2yd4vgTCzy4hEz+9Elj4+ADmTR2Dd04P90l4bRWW+IK9dR0UKDiY63KBW1hOfwkGxQdoMkzT3wxYNaqG/iad/J3vRWIHM6GKZ3ls39vSPi9CnWZHNZRc0Qs7kEWjfmI3TduO0JI/hiEvhYGVm1mgZfWkqciaPQFLt6kg1cdPo1jJ4HWYrcwmtGH+1ZXPpMfqU1irZX6jKjP7t0lx+fWyGUUFcCoeuLeojZ/KIilQUTiWhCqF1ciJevfViqeP+45aLKsa/qed5j4lJI+W737ptB7PwkSvxIPvqMwpgtVIcc0OPFujcrF7IPjtfGB6V6+LyJwZIT1XcVXt6/cfNF6FTs/Pqhs7NQ689EH82kLpbVj1kPQ+HyQUVKbWqJ+DRwRfikUFpeMDCegVM7FOvpm9uK2fDwkES3Vs1wKu3dg9bySyhCuHOvvYaYxvXq4mcySNwY8+WlaJou7aIXuWRGiQ768e/74NvH7ocAPy8ehZE4eX0z1si2zX5ugZGayz+ZcLV+OKBywB4rtEjgy6sEGKyH/h801ow8UEk8TROgoWDDbRxULprvUAwkm/+4gDF3m/u1RL92jeqGNs3vXg0No+be7WMSP3n62Lgn4E0NM3q16pIkKiaHq1iJiM9E8PEvXAY3s26mtJeRl3SKnwnC1n++AAs+ctVho5t3TARt/ep7DbbJmxGV0NThcRgSYKQ3JLeEl1b1JMeNe423TMTn8S9cPjHzRfjSgkBY+Ov6Rhx32ABMIFy95ihcYC4iht6+LtStm6YaKp4TzRV7maM7Su9IpYqx4LGdWvi24f6o5nkmhm1FdhcGPegL+3rZOJeONSuURXT
7+ltaoycySP8MmAaoYHkL02gJ9TJN8mtGxCIYA/xrZMTpfydAOCN2ytX1nN6hks9bqmix8jn7d/1xDeaLc7p8COMSR4z4K0TDNWlPwFr6tgGivhd9/RgaTWhv33ocqT5pPN2k3AAgM7N6mHLwZP4bGxfu5fCWMiwruZryVtF3O8cvBhNnztaoueR6nwr0RppI8U3gviefv7V3ZJrV0didTnPIvVqVgso5GZoN9qereVXPJONd1cn62/CMLJh4aAxoENjQ8clueiJ9YsH1GR2vSClTiWjtGqXvWD5q7xeUK3DGMSdgFc4hMrFxTB2wo8tMUxSYnUcOnmu4r3K/DrebKX39VdXExoAeqcmV9z8WybVQm7B2YrP0prUxdQx6Y5Oj+IlRau+5obsnEx8wsJBo3aN80+7RIFdI7c8NxQ/Zh3BV+v2Y/7mPKnz73phuNTxAOC9uy/B91sP4+nZmbiqg/oU3oAal1IAWPqXqzBtxW48+5suFRW1vh7XD/t0wgEABgaol+BE/vXb7pi/OQ+dwkTUM4xdsHDQ0GfG/O7h/hj22o9+fRKrV8XQLk3Ro1UD6cKhioK0zc3q18Lovm0wuFMT6Z5QvrRt5DEQ+xqKZZHaqDaeG1nZ1bdhnRpoKKn+sdUk1a6O2yxIq84wRmHhoGP3i8MhBHC6OHQJ0QaJ7rEzAHJTTgdjcOcmmPNgP3QzkIKDYRjnwcJBBxGBCKhaJbQemGv+BuaiAKk0GIZxJywcAmBFgqyv/ngZSsoEWicnIowsYhiGsRwWDjbRozUnX2MYxrnwMyvDMAzjBwsHhmEYxg8WDkF49870kJ+vfWoQfn16sEWrYRiGsRa2OQShY7PK0cS+YQhu9a9nGIaJBN45BKG2T0K06wPUQWAYholVWDgEIal2dcwe1w+TuN4vwzBxCAuHEHRv1QD1tDTX1TlBGsMwcYSpOx4RNSCiWUS0jYi2EtGlRJRMRIuIKEv7naT1JSJ6nYiyiWgjEfXUjTNG659FRGPMnpRMhndrhvuvvAAThneyeykMwzCWYfZx+N8A5gshOgK4GMBWAOMBLBZCpAFYrL0HgGsApGk/YwG8BQBElAxgIoA+AHoDmOgVKE6gWkIVjL+mo7JCOQzDME7EsHAgovoArgAwFQCEEMVCiOMARgKYrnWbDuB67fVIAB8IDysBNCCiZgCGAlgkhDgmhCgAsAjAMKPrYhiGYcxjZufQFkA+gPeI6FciepeIagNoIoQ4qPXJA+BNsN8CwD7d8blaW7B2hmEYxibMCIeqAHoCeEsI0QPAaZxXIQEAhBACkFcHkYjGElEGEWXk5+fLGpZhGIbxwYxwyAWQK4RYpb2fBY+wOKSpi6D9Pqx9vh9AK93xLbW2YO1+CCGmCCHShRDpKSnWVDZjGIaJRwwLByFEHoB9RNRBaxoIYAuAOQC8HkdjAHytvZ4D4E7Na6kvgBOa+mkBgCFElKQZoodobQzDMIxNmE2f8RCAj4moOoBdAO6GR+DMJKJ7AewB8Fut7zwAwwFkAzij9YUQ4hgRTQKwRuv3nBDimMl1MQzDMCYgoaoivGLS09NFRkaG3ctgGIZxFUS0VggROrMoOEKaYRiGCYBrdw5ElA+P2soIjQAckbgcp8LnGVvEw3nGwzkC9p5nGyFEWI8e1woHMxBRRiTbKrfD5xlbxMN5xsM5Au44T1YrMQzDMH6wcGAYhmH8iFfhMMXuBVgEn2dsEQ/nGQ/nCLjgPOPS5sAwDMOEJl53DgzDMEwI4ko4ENEwItquFRwaH/4I+yGiVkS0hIi2ENFmInpYa5dWVImIehHRJu2Y14mIrD9TgIgStAy/32rv2xLRKm1dn2mR+CCiGtr7bO3zVN0YE7T27UQ0VNfumGuvukiWE64nET2qfV8ziehTIqoZK9eTiKYR0WEiytS1Kb9+weZQhhAiLn4AJADYCaAdgOoANgDobPe6Ilh3MwA9tdd1AewA0BnAywDGa+3jAbykvR4O4DsABKAv
gFVaezI8KU6SASRpr5O0z1ZrfUk79hqbzvUxAJ8A+FZ7PxPAKO312wAe0F7/EcDb2utRAD7TXnfWrmsNeFLK79Suu6OuPTx1Tn6vva4OoEEsXU94Uu7vBlBLdx3vipXrCU8dm54AMnVtyq9fsDmUnadd/yBW/wC4FMAC3fsJACbYvS4D5/E1gMEAtgNoprU1A7Bde/0OgNt0/bdrn98G4B1d+ztaWzMA23TtlfpZeF4t4akceDWAb7V/jCMAqvpeP3gSM16qva6q9SPfa+rt56RrD6C+duMkn/aYuZ44X6MlWbs+38JT1CtmrieAVFQWDsqvX7A5VP3Ek1rJ9UWFtO12DwCrIK+oUgvttW+71bwG4AkA5dr7hgCOCyFKA6yr4ly0z09o/d1QUEp1kSzbr6cQYj+AfwLYC+AgPNdnLWLzenqx4voFm0MJ8SQcXA0R1QHwBYBHhBAn9Z8Jz6OEa93OiOhaAIeFEGvtXosFWF4ky2o0XfhIeARhcwC1EUelf624flbMEU/CIeKiQk6DiKrBIxg+FkJ8qTXLKqq0X3vt224l/QBcR0Q5AGbAo1r6Nzx1xr1p5fXrqjgX7fP6AI5CQkEpC1BdJMsJ13MQgN1CiHwhRAmAL+G5xrF4Pb1Ycf2CzaGEeBIOawCkaR4T1eExfM2xeU1h0TwVpgLYKoT4l+4jKUWVtM9OElFfba47dWNZghBighCipRAiFZ7r8oMQ4g4ASwDcrHXzPUfvud+s9Rda+yjN+6UtgDR4jHuOufZCcZEsJ1xPeNRJfYkoUVuD9xxj7nrqsOL6BZtDDVYacez+gcdzYAc8ng5P2r2eCNd8OTzbx40A1ms/w+HRyS4GkAXgewDJWn8C8KZ2jpsApOvGugeeYkvZAO7WtacDyNSOeQM+xlKLz/cqnPdWagfPzSAbwOcAamjtNbX32drn7XTHP6mdx3bovHScdO0BdAeQoV3T2fB4q8TU9QTwNwDbtHV8CI/HUUxcTwCfwmNLKYFnJ3ivFdcv2ByqfjhCmmEYhvEjntRKDMMwTISwcGAYhmH8YOHAMAzD+MHCgWEYhvGDhQPDMAzjBwsHhmEYxg8WDgzDMIwfLBwYhmEYP/4flpinpfLxCPgAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Hourly load workbooks: 2005-2015 in one file, 2016 in the other\n",
    "file1 = \"./data/Hourly_load_data_(2005-2015).xlsx\"\n",
    "file2 = \"./data/Hourly-load-data-for-2016.xlsx\"\n",
    "\n",
    "df1 = pd.read_excel(file1, sheet_name=0)\n",
    "df2 = pd.read_excel(file2, sheet_name=0)\n",
    "\n",
    "# Keep only the load column from each workbook\n",
    "load1 = df1[\"Alberta Internal Load\"]\n",
    "load2 = df2[\"Alberta Internal Load\"]\n",
    "\n",
    "# Concatenate the 2005-2015 and 2016 series into one continuous series\n",
    "load = pd.concat([load1, load2], ignore_index=True)\n",
    "plt.figure()\n",
    "plt.plot(load)\n",
    "\n",
    "# Standardise the series to zero mean and unit standard deviation\n",
    "load_norm = (load - np.mean(load)) / np.std(load)\n",
    "\n",
    "# Add a trailing feature axis: (n,) -> (n, 1).\n",
    "# Index the underlying ndarray rather than the Series: multi-dimensional\n",
    "# indexing on a pandas Series is deprecated and fails on newer pandas.\n",
    "load_norm = load_norm.values[:, np.newaxis]\n",
    "print('load_norm shape: {}'.format(load_norm.shape))\n",
    "print(load_norm[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters shared by the cells below\n",
    "time_step = 20     # number of consecutive samples in one window\n",
    "rnn_unit = 64      # hidden units of the LSTM cell\n",
    "batch_size = 64    # sequences per mini-batch\n",
    "input_size = 1     # features per time step fed to the network\n",
    "output_size = 1    # features per time step produced by the network\n",
    "lr = 0.0001        # learning rate\n",
    "\n",
    "# Slide a window of `time_step` samples over the series; each input window is\n",
    "# paired with the same window shifted one sample ahead (next-step targets).\n",
    "num_windows = len(load_norm) - time_step - 1\n",
    "train_x = [load_norm[s: s + time_step].tolist() for s in range(num_windows)]\n",
    "train_y = [load_norm[s + 1: s + time_step + 1].tolist() for s in range(num_windows)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the window lists to arrays and hold out everything after the first\n",
    "# n_train windows as the test split.\n",
    "# NOTE: train_x / train_y are overwritten here, so re-run the window-building\n",
    "# cell before re-running this one.\n",
    "n_train = 100000\n",
    "train_x = np.array(train_x)\n",
    "train_y = np.array(train_y)\n",
    "test_x = train_x[n_train:]\n",
    "test_y = train_y[n_train:]\n",
    "train_x = train_x[:n_train]\n",
    "# The targets must be sliced from train_y; slicing train_x here would make the\n",
    "# network learn to copy its input instead of predicting the next step.\n",
    "train_y = train_y[:n_train]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_x shape: (100000, 20, 1)\n",
      "train_y shape: (100000, 20, 1)\n",
      "test_x shape: (5171, 20, 1)\n",
      "test_y shape: (5171, 20, 1)\n",
      "[[[-0.85610657]\n",
      "  [-0.99451862]\n",
      "  [-1.14643429]\n",
      "  [-1.2432102 ]\n",
      "  [-1.32423189]\n",
      "  [-1.3366102 ]\n",
      "  [-1.23308249]\n",
      "  [-1.17906803]\n",
      "  [-1.08454272]\n",
      "  [-1.03165356]\n",
      "  [-0.9776391 ]\n",
      "  [-0.7030656 ]\n",
      "  [-0.62204391]\n",
      "  [-0.62992102]\n",
      "  [-0.63329693]\n",
      "  [-0.57365596]\n",
      "  [-0.26307282]\n",
      "  [ 0.09139707]\n",
      "  [ 0.00699948]\n",
      "  [-0.09315233]]\n",
      "\n",
      " [[-0.99451862]\n",
      "  [-1.14643429]\n",
      "  [-1.2432102 ]\n",
      "  [-1.32423189]\n",
      "  [-1.3366102 ]\n",
      "  [-1.23308249]\n",
      "  [-1.17906803]\n",
      "  [-1.08454272]\n",
      "  [-1.03165356]\n",
      "  [-0.9776391 ]\n",
      "  [-0.7030656 ]\n",
      "  [-0.62204391]\n",
      "  [-0.62992102]\n",
      "  [-0.63329693]\n",
      "  [-0.57365596]\n",
      "  [-0.26307282]\n",
      "  [ 0.09139707]\n",
      "  [ 0.00699948]\n",
      "  [-0.09315233]\n",
      "  [-0.1584198 ]]]\n"
     ]
    }
   ],
   "source": [
    "# Report the shape of each split, then show the first two input windows.\n",
    "for label_template, split in (\n",
    "        ('train_x shape: {}', train_x),\n",
    "        ('train_y shape: {}', train_y),\n",
    "        ('test_x shape: {}', test_x),\n",
    "        ('test_y shape: {}', test_y)):\n",
    "    print(label_template.format(split.shape))\n",
    "print(train_x[:2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Placeholders for one batch of input windows (X) and their targets (Y)\n",
    "X = tf.placeholder(tf.float32, [None, time_step, input_size])\n",
    "Y = tf.placeholder(tf.float32, [None, time_step, output_size])\n",
    "\n",
    "# Weights and biases of the input projection ('in') and output projection ('out')\n",
    "weights = {\n",
    "    'in': tf.Variable(tf.random_normal([input_size, rnn_unit])),\n",
    "    # Projects rnn_unit hidden features to output_size values, so its second\n",
    "    # dimension must be output_size (matching Y and biases['out']).\n",
    "    'out': tf.Variable(tf.random_normal([rnn_unit, output_size]))\n",
    "}\n",
    "biases = {\n",
    "    'in': tf.Variable(tf.constant(0.1, shape=[rnn_unit,])),\n",
    "    'out': tf.Variable(tf.constant(0.1, shape=[output_size,]))\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lstm(batch):\n",
    "    \"\"\"Build the LSTM graph on top of the global placeholder X.\n",
    "\n",
    "    Args:\n",
    "        batch: number of sequences per batch; used to size the cell's zero\n",
    "            initial state.\n",
    "\n",
    "    Returns:\n",
    "        (pred, final_states): pred is the 2-D output-layer result with one row\n",
    "        per (sequence, time step) pair; final_states is the state returned by\n",
    "        tf.nn.dynamic_rnn.\n",
    "    \"\"\"\n",
    "    # Input projection, done as a single 2-D matmul:\n",
    "    #   [batch, time_step, input_size] -> [batch * time_step, input_size]\n",
    "    #   [batch * time_step, input_size] x [input_size, rnn_unit] + bias\n",
    "    flat_x = tf.reshape(X, [-1, input_size])\n",
    "    projected = tf.matmul(flat_x, weights['in']) + biases['in']\n",
    "    # Back to 3-D, [batch, time_step, rnn_unit], which is what the LSTM cell consumes\n",
    "    rnn_input = tf.reshape(projected, [-1, time_step, rnn_unit])\n",
    "    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(rnn_unit)\n",
    "    zero_state = lstm_cell.zero_state(batch, dtype=tf.float32)\n",
    "    # rnn_outputs holds the cell output at every time step; final_states is the\n",
    "    # state after the last step\n",
    "    rnn_outputs, final_states = tf.nn.dynamic_rnn(lstm_cell, rnn_input, initial_state=zero_state, dtype=tf.float32)\n",
    "    # Output projection, again flattened to 2-D:\n",
    "    #   [batch, time_step, rnn_unit] -> [batch * time_step, rnn_unit]\n",
    "    #   then matmul with weights['out'] and add biases['out']\n",
    "    flat_outputs = tf.reshape(rnn_outputs, [-1, rnn_unit])\n",
    "    pred = tf.matmul(flat_outputs, weights['out']) + biases['out']\n",
    "    return pred, final_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _mean_validation_loss(sess, loss):\n",
    "    \"\"\"Return the mean of `loss` over all full batches of test_x / test_y.\"\"\"\n",
    "    total_loss = 0.0\n",
    "    num_batches = 0\n",
    "    # lstm() sizes its initial state for exactly batch_size sequences, so only\n",
    "    # full batches can be fed; a shorter tail of the split is left out.\n",
    "    for start_val in range(0, len(test_x) - batch_size + 1, batch_size):\n",
    "        end_val = start_val + batch_size\n",
    "        # Fetch the scalar loss directly: creating tf.squeeze ops here would add\n",
    "        # new nodes to the graph on every validation batch.\n",
    "        batch_loss = sess.run(loss, feed_dict={X: test_x[start_val:end_val], Y: test_y[start_val:end_val]})\n",
    "        total_loss += float(batch_loss)\n",
    "        num_batches += 1\n",
    "    return total_loss / num_batches\n",
    "\n",
    "\n",
    "def train_lstm(num_steps=10000, eval_every=500):\n",
    "    \"\"\"Train the LSTM and checkpoint the weights with the lowest validation loss.\n",
    "\n",
    "    Builds the graph with lstm(batch_size), minimises the mean squared error\n",
    "    between the prediction and Y with Adam (learning rate lr), and cycles\n",
    "    through train_x / train_y in consecutive mini-batches. Every eval_every\n",
    "    steps the loss is averaged over the test split; whenever it improves, the\n",
    "    session is saved to ./model/best_validation.\n",
    "\n",
    "    Args:\n",
    "        num_steps: number of optimisation steps to run.\n",
    "        eval_every: validate (and possibly save) every this many steps.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if either split holds fewer than batch_size samples.\n",
    "    \"\"\"\n",
    "    if min(len(train_x), len(test_x)) < batch_size:\n",
    "        raise ValueError('train and test splits must each hold at least batch_size samples')\n",
    "    pred, _ = lstm(batch_size)\n",
    "    # Loss: mean squared error over every time step of every sequence\n",
    "    loss = tf.reduce_mean(tf.square(tf.reshape(pred, [-1]) - tf.reshape(Y, [-1])))\n",
    "    train_op = tf.train.AdamOptimizer(lr).minimize(loss)\n",
    "    saver = tf.train.Saver()  # saves the graph variables\n",
    "    save_dir = './model'\n",
    "    os.makedirs(save_dir, exist_ok=True)\n",
    "    save_path = os.path.join(save_dir, 'best_validation')  # checkpoint with the best validation loss\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        start = 0\n",
    "        best_loss = np.inf\n",
    "        for step in range(num_steps):\n",
    "            end = start + batch_size\n",
    "            if end > len(train_x):\n",
    "                # Not enough samples left for a full batch: wrap to the start\n",
    "                # and still train on this step instead of skipping it.\n",
    "                start, end = 0, batch_size\n",
    "            _, loss_ = sess.run([train_op, loss], feed_dict={X: train_x[start:end], Y: train_y[start:end]})\n",
    "            start = end\n",
    "            if step % eval_every == 0:\n",
    "                loss_val = _mean_validation_loss(sess, loss)\n",
    "                print('setp: {}      train loss: {}      validation loss: {}'.format(step, loss_, loss_val))\n",
    "                if loss_val < best_loss:\n",
    "                    print('best_loss: {} -> {}, save model.'.format(best_loss, loss_val))\n",
    "                    best_loss = loss_val\n",
    "                    saver.save(sess=sess, save_path=save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "setp: 0      train loss: 7.9421210289001465      validation loss: 7.800604730844498\n",
      "best_loss: inf -> 7.800604730844498, save model.\n",
      "setp: 500      train loss: 0.004870368167757988      validation loss: 0.3511730349389836\n",
      "best_loss: 7.800604730844498 -> 0.3511730349389836, save model.\n",
      "setp: 1000      train loss: 0.004037308506667614      validation loss: 0.07431437058839947\n",
      "best_loss: 0.3511730349389836 -> 0.07431437058839947, save model.\n",
      "setp: 1500      train loss: 0.002303384244441986      validation loss: 0.055476910434663296\n",
      "best_loss: 0.07431437058839947 -> 0.055476910434663296, save model.\n",
      "setp: 2000      train loss: 0.001969343051314354      validation loss: 0.05432232823222875\n",
      "best_loss: 0.055476910434663296 -> 0.05432232823222875, save model.\n",
      "setp: 2500      train loss: 0.0013308876659721136      validation loss: 0.054471371392719445\n",
      "setp: 3000      train loss: 0.0007912603323347867      validation loss: 0.04746811070945114\n",
      "best_loss: 0.05432232823222875 -> 0.04746811070945114, save model.\n",
      "setp: 3500      train loss: 0.0006152402493171394      validation loss: 0.047773530287668106\n",
      "setp: 4000      train loss: 0.00035344838397577405      validation loss: 0.04957545504439622\n",
      "setp: 4500      train loss: 0.0006493263645097613      validation loss: 0.04814686665777117\n",
      "setp: 5000      train loss: 0.0002610427909530699      validation loss: 0.04512298428453505\n",
      "best_loss: 0.04746811070945114 -> 0.04512298428453505, save model.\n",
      "setp: 5500      train loss: 0.0004519816138781607      validation loss: 0.046056712488643826\n",
      "setp: 6000      train loss: 0.00023932449403218925      validation loss: 0.04570641254540533\n",
      "setp: 6500      train loss: 0.00014523147547151893      validation loss: 0.0453616946702823\n",
      "setp: 7000      train loss: 0.00011952155909966677      validation loss: 0.04618830240797252\n",
      "setp: 7500      train loss: 0.0006494203116744757      validation loss: 0.04371175009291619\n",
      "best_loss: 0.04512298428453505 -> 0.04371175009291619, save model.\n",
      "setp: 8000      train loss: 0.0001230952184414491      validation loss: 0.04550474444404244\n",
      "setp: 8500      train loss: 0.00010346870112698525      validation loss: 0.045836743572726844\n",
      "setp: 9000      train loss: 0.00016646142466925085      validation loss: 0.045050444058142604\n",
      "setp: 9500      train loss: 8.016314677661285e-05      validation loss: 0.04673052036669105\n"
     ]
    }
   ],
   "source": [
    "train_lstm()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1]\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "# Scratch check of tf.squeeze() (left commented out; the stored output of this cell is from an earlier run)\n",
    "# sess = tf.Session()\n",
    "# a = tf.Variable([1])\n",
    "# sess.run(a.initializer)\n",
    "# print(sess.run(a))\n",
    "# b = tf.squeeze(a)\n",
    "# print(sess.run(b))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
